import pandas as pd
import statsmodels.formula.api as smf
import numpy as np

# Read the dataset (the version that carries dep_ratio and hh_type_strat)
df = pd.read_csv('cleaned_with_male_head.csv')

# Cast the discrete columns to pandas categoricals in one astype call
df = df.astype({
    'CHILD': 'category',
    'year': 'category',
    'H_TYPE': 'category',
    'hh_type_strat': 'category',
})

# Within-group centring helper (applied to dep_ratio as well)
def demean(df, vars_to_demean, group_col='code_id_hh'):
    """Return a copy of ``df`` with selected columns centred on their group means.

    Args:
        df: Input frame; it is copied, never modified.
        vars_to_demean: Column names to centre. Names that are not columns of
            the frame are ignored.
        group_col: Column whose values define the groups.

    Returns:
        A new DataFrame in which each listed column has had the mean of its
        ``group_col`` group subtracted.
    """
    centred = df.copy()
    for name in vars_to_demean:
        if name not in centred.columns:
            # Unknown column names are skipped rather than raising.
            continue
        per_group_mean = centred.groupby(group_col)[name].transform('mean')
        centred[name] = centred[name].sub(per_group_mean)
    return centred

# Columns to centre on their code_id_hh group means (demean's default grouping)
vars_to_demean = [
    'fert',
    'htype',
    'dep_ratio',
    'mar',
    'emp',
    'pri',
    'sec',
    'high',
    'inc',
    'fam',
]
df_demean = demean(df, vars_to_demean=vars_to_demean)

# Coefficient formatter: point estimate plus a +/- 1.96 * SE interval
def get_param_info(model, param):
    """Format one coefficient of ``model`` as ``"coef (low to high)"``.

    Args:
        model: Object exposing ``params`` and ``bse`` with a ``.get`` lookup
            keyed by term name.
        param: Name of the term to report.

    Returns:
        The estimate and the bounds ``estimate -/+ 1.96 * se``, each rounded
        to 4 decimals, or ``"N/A"`` when the term is absent or its estimate
        is NaN.
    """
    estimate = model.params.get(param, np.nan)
    std_err = model.bse.get(param, np.nan)
    if np.isnan(estimate):
        return "N/A"
    half_width = 1.96 * std_err
    lower = estimate - half_width
    upper = estimate + half_width
    return "{} ({} to {})".format(
        round(estimate, 4), round(lower, 4), round(upper, 4)
    )

# Function to build full table from model
def build_full_table(model, model_name):
    """Assemble the reporting table for one fitted model.

    Args:
        model: Fitted results object; must expose ``nobs`` and ``rsquared``
            and be accepted by ``get_param_info``.
        model_name: Header of the column that holds the formatted estimates.

    Returns:
        A DataFrame with a ``'Variable'`` column of row labels and a
        ``model_name`` column. The first rows hold one formatted coefficient
        per term (``"N/A"`` for terms the model does not contain); the last
        two rows hold the observation count and R-squared.
    """
    # Each model term sits next to its display label so the two cannot drift
    # out of alignment when a row is added or removed.
    rows = [
        ('htype', 'Living with grandparent(s)^a'),
        ('C(CHILD)[T.1]', 'Sex of children^b: At least 1 son'),
        ('C(CHILD)[T.2]', '- At least 1 daughter'),
        ('C(CHILD)[T.3]', '- At least 1 son and 1 daughter'),
        ('mar', 'Proportion of married women'),
        ('emp', 'Proportion of employed women'),
        ('pri', 'Proportion of women with primary education'),
        ('sec', ' - Secondary education'),
        ('high', ' - High school education'),
        ('inc', 'Household income'),
        ('fam', 'Household size'),
        ('Intercept', 'Constant'),
        ('C(year)[T.2010.0]', 'Year 2010'),
        ('C(year)[T.2012.0]', 'Year 2012'),
        ('C(year)[T.2014.0]', 'Year 2014'),
        ('C(year)[T.2016.0]', 'Year 2016'),
    ]
    table = pd.DataFrame({
        'Variable': [label for _, label in rows],
        model_name: [get_param_info(model, term) for term, _ in rows],
    })

    # Add footer rows for observations and R-squared.
    # dtype=object keeps the observation count an int: a plain [int, float]
    # list is inferred as float64, which would render the count as "1234.0".
    footer_values = pd.Series(
        [int(model.nobs), round(model.rsquared, 3)], dtype=object
    )
    footer = pd.DataFrame({
        'Variable': ['Observations', 'R-squared'],
        model_name: footer_values,
    })
    return pd.concat([table, footer], ignore_index=True)


# Stratified analysis: fit the same specification separately for each
# household type, print the resulting table and export it to CSV.
strat_types = [0, 1, 2]  # 0=Nuclear, 1=3-Gen, 2=Extended

# Display label and output-file suffix per stratum code. Looking these up by
# name keeps quotes out of the f-string replacement fields; reusing the outer
# quote character inside an f-string is a SyntaxError before Python 3.12.
strat_labels = {0: 'Nuclear', 1: '3-Generation', 2: 'Extended'}
strat_suffixes = {0: 'nuclear', 1: '3gen', 2: 'extended'}

# The specification is the same for every stratum, so define it once.
formula_strat = 'fert ~ htype + C(CHILD) + mar + emp + pri + sec + high + inc + fam + C(year)'
# Every column the fit reads: the formula terms plus the cluster identifier.
model_cols = [
    'fert', 'htype', 'CHILD', 'mar', 'emp', 'pri', 'sec', 'high', 'inc',
    'fam', 'year', 'code_id_hh',
]

# NOTE(review): df_demean was centred on household means computed over all
# rows before stratifying; if a household's hh_type_strat can differ between
# rows, the variables are not mean-zero within a stratum. Confirm intended.
for typ in strat_types:
    df_strat = df_demean[df_demean['hh_type_strat'] == typ]
    # Drop incomplete rows up front. The formula interface discards rows with
    # missing values on its own, while the cluster `groups` passed below would
    # keep its full length, so the two would no longer line up.
    df_strat = df_strat.dropna(subset=model_cols).copy()
    if df_strat.empty:
        continue
    model_strat = smf.ols(formula_strat, data=df_strat).fit(
        cov_type='cluster',
        cov_kwds={'groups': df_strat['code_id_hh']},
    )

    # Full table for each type
    label = strat_labels[typ]
    table_strat = build_full_table(model_strat, f'No. of children under 3 ({label})')
    print(f"\nFull Table for {label}:\n", table_strat)
    table_strat.to_csv(f'revise_table_5d_{strat_suffixes[typ]}.csv', index=False)